import torch
import copy
import torch.nn as nn
from tqdm import tqdm
import evaluate as eval_metric
from transformers.optimization import get_scheduler
import math
import itertools

def create_scheduler(args, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
    """
    Build a learning-rate scheduler for ``optimizer`` through ``get_scheduler``.

    Args:
        args: Configuration object; only ``args.lr_scheduler_type`` is read here.
        num_training_steps (int): The number of training steps to do.
        optimizer (torch.optim.Optimizer): The optimizer whose learning rate is scheduled.
            This module-level function has no fallback optimizer, so the value is
            forwarded unchanged; callers are expected to pass a real optimizer.

    Returns:
        The scheduler object returned by ``get_scheduler`` (configured with zero warmup steps).
    """
    # The original expression `optimizer if optimizer is None else optimizer`
    # always evaluated to `optimizer`; pass it through directly.
    return get_scheduler(
        args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
        scheduler_specific_kwargs={},
    )

class Model:
    """Distributed bilevel training wrapper around a model with LoRA-style parameters.

    Trainable parameters are split by name into two groups:

    * lower-level variables: names containing ``lora_C`` or ``lora_D``
      (``trainable_params_B``), kept local to each rank;
    * upper-level variables: every other trainable parameter
      (``trainable_params_A``), periodically averaged across ranks.

    ``dist`` is used for ``reduce``/``broadcast``/``barrier``/``all_gather`` and
    ``ReduceOp`` -- presumably ``torch.distributed``; confirm against the caller.
    """

    def __init__(self, args, model, dist):
        super().__init__()
        self.args = args
        self.model = model
        self.dist = dist
        self.create_opt()

    @staticmethod
    def _is_lower_level(name):
        """Return True when a parameter name belongs to the lower-level (lora_C / lora_D) group."""
        return 'lora_C' in name or 'lora_D' in name

    @staticmethod
    def _batch_to_cuda(batch):
        """Move the three tensors read from a batch dict to the GPU and return them as a tuple."""
        return batch["input_ids"].cuda(), batch["attention_mask"].cuda(), batch["label"].cuda()

    def _average_upper_level_params(self):
        """Average the upper-level trainable parameters over all ranks, in place."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and not self._is_lower_level(name):
                # Sum onto rank 0, scale there, then push rank 0's value to everyone.
                # The division also runs on the other ranks, but the broadcast
                # overwrites their result.
                self.dist.reduce(param.data, dst=0, op=self.dist.ReduceOp.SUM)
                param.data /= self.args.world_size
                self.dist.broadcast(param.data, src=0)

    def create_opt(self):
        """Partition trainable parameters into upper/lower groups and report parameter counts.

        Populates ``trainable_params`` (all trainable parameters, each exactly once),
        ``trainable_params_A`` / ``m`` (upper level and a zero buffer per parameter) and
        ``trainable_params_B`` / ``z`` (lower level and a zero buffer per parameter).
        The zero buffers are allocated on the GPU.
        """
        self.trainable_params_A = []
        self.trainable_params_B = []
        self.trainable_params = []
        self.z = []
        self.m = []
        num_trainable_params = 0
        all_param = 0
        for name, param in self.model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel
            # NOTE(review): presumably 4-bit storage packs two values per element,
            # hence the doubling -- confirm against the quantization library.
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params

            if not param.requires_grad:
                continue

            # Register each trainable parameter once (it used to be appended twice).
            self.trainable_params.append(param)
            zero_buffer = torch.zeros(param.shape).cuda()
            if self._is_lower_level(name):
                self.trainable_params_B.append(param)
                self.z.append(zero_buffer)
            else:
                self.trainable_params_A.append(param)
                self.m.append(zero_buffer)
            num_trainable_params += num_params

        if self.args.rank % self.args.world_size == 0:
            # Guard the ratio so a model without parameters does not raise.
            trainable_pct = 100 * num_trainable_params / all_param if all_param else 0.0
            print(
                f"trainable params: {num_trainable_params:,d} || all params: {all_param:,d} || trainable%: {trainable_pct}")

    def train(self, train_dataset, eval_dataset, train_lower):
        """Run the bilevel training loop with periodic averaging and per-epoch evaluation.

        Args:
            train_dataset: Dataset providing the batches passed to ``learn`` as the upper-level data.
            eval_dataset: Dataset evaluated at the end of every epoch.
            train_lower: Dataset cycled endlessly to supply batches to ``learn`` as the lower-level data.

        Training stops once ``args.com_rounds`` averaging rounds have happened or
        ``args.num_epochs`` epochs have run. On the main rank the accumulated metrics
        and losses are written to ``args.save_path + '.pkl'`` / ``'.txt'`` after each epoch.
        """
        args = self.args
        is_main_rank = args.rank % args.world_size == 0
        is_mnli = args.dataset == "mnli"

        self.model = self.model.cuda()
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.batch_size)
        # NOTE: itertools.cycle stores every batch it has yielded so it can replay them.
        train_low_loader = itertools.cycle(torch.utils.data.DataLoader(train_lower, batch_size=args.batch_size))

        self.optimizer_outer = torch.optim.AdamW(self.trainable_params_A, lr=args.lr_B, weight_decay=1e-4)
        self.optimizer_inner = torch.optim.SGD(self.trainable_params_B, lr=args.lr_A, weight_decay=1e-4)

        num_training_steps = args.com_rounds * args.com_interval
        lr_scheduler_outer = create_scheduler(args, num_training_steps, self.optimizer_outer)

        metric = eval_metric.load("glue", args.dataset)
        loss_list = []
        test_results_list = {"mnli-m/acc": [], "mnli-mm/acc": []} if is_mnli else {}
        cut_round = 0
        for epoch in range(args.num_epochs):
            # evaluate() leaves the model in eval mode, so training mode has to be
            # re-enabled at the start of every epoch, not just before the first.
            self.model.train()

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.num_epochs}", disable=not is_main_rank)
            for batch_idx, batch in enumerate(progress_bar):
                if cut_round >= args.com_rounds:
                    break
                input_ids, attention_mask, labels = self._batch_to_cuda(batch)
                self.optimizer_outer.zero_grad()
                loss = self.learn(input_ids, attention_mask, labels, train_low_loader)
                loss_value = loss.item()

                loss_list.append(round(loss_value, 4))
                torch.nn.utils.clip_grad_norm_(self.trainable_params_A, args.max_grad_norm)
                # The optimizer must step before the scheduler; the reverse order
                # skips the first value of the learning-rate schedule.
                self.optimizer_outer.step()
                lr_scheduler_outer.step()

                torch.cuda.empty_cache()

                progress_bar.set_postfix({"Rank": args.rank, "Train loss": loss_value, "lr": lr_scheduler_outer.get_last_lr()[0]})
                if batch_idx % args.com_interval == 0:
                    # average the upper-level variables
                    self._average_upper_level_params()
                    cut_round += 1

            # Synchronize all ranks around evaluation, before dist.all_gather.
            self.dist.barrier()
            eval_results = self.evaluate(eval_loader, metric)
            print(f'rank: {args.rank}, {eval_results}')
            self.dist.barrier()

            # Convert the evaluation results to a tensor.
            if is_mnli:
                # NOTE(review): both slots hold the same "accuracy" value, so the
                # matched and mismatched numbers reported below are identical.
                metric_values = [eval_results["accuracy"], eval_results["accuracy"]]
            else:
                metric_values = list(eval_results.values())
            eval_results_tensor = torch.tensor(metric_values, dtype=torch.float32).cuda()

            # Gather evaluation results from all clients.
            eval_results_list = [torch.zeros_like(eval_results_tensor) for _ in range(args.world_size)]
            self.dist.all_gather(eval_results_list, eval_results_tensor)

            if is_main_rank:
                # One row per rank; average each metric column across ranks.
                averaged = torch.stack(eval_results_list).cpu().mean(dim=0).tolist()
                if is_mnli:
                    test_results_list["mnli-m/acc"].append(averaged[0])
                    test_results_list["mnli-mm/acc"].append(averaged[1])
                else:
                    for idx, key in enumerate(eval_results):
                        test_results_list.setdefault(key, []).append(round(averaged[idx], 4))

                print(f"Epoch {epoch + 1}/{args.num_epochs}:")
                print(f"Evaluation results list: {test_results_list}")
                torch.save((test_results_list, loss_list), args.save_path + '.pkl')
                with open(args.save_path + '.txt', 'w') as f:
                    f.write(str({'Exp config': str(args), 'Average evaluation results': str(test_results_list)}))
            if cut_round >= args.com_rounds:
                if is_main_rank:
                    print(f"Average evaluation results: {test_results_list}")
                break

    def evaluate(self, dataloader, metric):
        """Average the upper-level parameters across ranks, then score the model on ``dataloader``.

        Args:
            dataloader: Iterable of batch dicts with ``input_ids``, ``attention_mask`` and ``label``.
            metric: Object exposing ``add_batch(predictions=..., references=...)`` and ``compute()``.

        Returns:
            Whatever ``metric.compute()`` returns.

        The model is left in eval mode when this method returns.
        """
        self._average_upper_level_params()
        self.model.eval()
        progress_bar = tqdm(dataloader, desc="Evaluation", unit="batch")

        # Only predictions are needed, so skip building autograd graphs.
        with torch.no_grad():
            for batch in progress_bar:
                input_ids, attention_mask, labels = self._batch_to_cuda(batch)
                outputs = self.model(input_ids, attention_mask=attention_mask)
                predictions = torch.argmax(outputs.logits, dim=-1)
                metric.add_batch(predictions=predictions, references=labels)

        return metric.compute()

    def regularizaton(self):
        """Return half the sum of the norms of the lower-level parameters.

        The norms are taken on detached data, so the result carries no gradient.
        With no lower-level parameters the result is ``0.0``.
        """
        return 0.5 * sum(x.data.detach().norm() for x in self.trainable_params_B)

    def learn(self, input_ids, attention_mask, labels, train_low_loader):
        """Do one lower-level step and set the upper-level gradients for the outer optimizer.

        Args:
            input_ids, attention_mask, labels: Upper-level batch tensors (already on the GPU).
            train_low_loader: Iterator yielding lower-level batch dicts; two batches are drawn per call.

        Returns:
            The lower-level loss of the original model on the second drawn batch, detached.

        Side effects: ``.grad`` of every parameter in ``trainable_params_A`` is overwritten,
        and the lower-level parameters of ``self.model`` are replaced by the updated copies.
        """
        # Work on a copy so the lower-level step does not touch self.model until the end.
        model_copy = copy.deepcopy(self.model)
        trainable_params_y = []
        trainable_params_x = []
        for name, param in model_copy.named_parameters():
            if not param.requires_grad:
                continue
            if self._is_lower_level(name):
                trainable_params_y.append(param)
            else:
                trainable_params_x.append(param)
        optimizer_2 = torch.optim.SGD(trainable_params_y, lr=self.args.lr_A, weight_decay=1e-4)

        # One-step learning for the inner loop: a single SGD step on the lower-level variables.
        input_ids_in, attention_mask_in, labels_in = self._batch_to_cuda(next(iter(train_low_loader)))
        outputs = model_copy(input_ids_in, attention_mask=attention_mask_in, labels=labels_in)

        optimizer_2.zero_grad()
        outputs.loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params_y, self.args.max_grad_norm)
        optimizer_2.step()
        model_copy.zero_grad()

        # Upper-level loss at the updated lower-level variables, and its gradients
        # with respect to both parameter groups of the copy.
        outputs = model_copy(input_ids, attention_mask=attention_mask, labels=labels)
        loss_upper = outputs.loss
        F_x = torch.autograd.grad(loss_upper, trainable_params_x, retain_graph=True)
        F_y = torch.autograd.grad(loss_upper, trainable_params_y)

        # Lower-level loss on the original model for a fresh batch; create_graph=True
        # keeps G_y differentiable so it can be differentiated again below.
        input_ids_in, attention_mask_in, labels_in = self._batch_to_cuda(next(iter(train_low_loader)))
        outputs_in = self.model(input_ids_in, attention_mask=attention_mask_in, labels=labels_in)
        loss_in = outputs_in.loss
        G_y = torch.autograd.grad(loss_in, self.trainable_params_B, create_graph=True)

        # Differentiate G_y with respect to the upper-level variables, weighted by F_y.
        G_xy = torch.autograd.grad(G_y, self.trainable_params_A, grad_outputs=F_y)
        self.optimizer_outer.zero_grad()

        torch.cuda.empty_cache()
        for param, f_x, g_xy in zip(self.trainable_params_A, F_x, G_xy):
            param.grad = f_x - self.args.lr_A * g_xy

        # Copy the updated lower-level variables to the original model. Both models
        # enumerate parameters in the same order because one is a deepcopy of the other.
        lower_level_params = (
            param for name, param in self.model.named_parameters()
            if param.requires_grad and self._is_lower_level(name)
        )
        for param, updated in zip(lower_level_params, trainable_params_y):
            param.data = updated.detach().clone()

        # Detach so the caller does not keep the double-backward graph alive.
        return loss_in.detach()

    def learn1(self, input_ids, attention_mask, labels):
        """Take one plain step of ``optimizer_inner`` on the given batch and return the loss."""
        self.optimizer_inner.zero_grad()
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        self.optimizer_inner.step()
        return loss
